﻿// Copyright (c) .NET Foundation and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Roslynator.CSharp;

/// <summary>
/// Validates whether a document refactoring is applicable and, if so, returns a deferred delegate
/// that performs it against a <see cref="Document"/>.
/// </summary>
internal static class DocumentRefactoringFactory
{
    /// <summary>
    /// Returns a delegate that calls <c>DocumentRefactorings.ChangeTypeAndAddAwaitAsync</c> for
    /// <paramref name="variableDeclarator"/>, passing the <c>T</c> of the <c>Task&lt;T&gt;</c> that
    /// <paramref name="newTypeSymbol"/> is (or derives from) as the type to use.
    /// </summary>
    /// <param name="document">The document containing the declaration.</param>
    /// <param name="variableDeclaration">The variable declaration being refactored.</param>
    /// <param name="variableDeclarator">The declarator within <paramref name="variableDeclaration"/>.</param>
    /// <param name="newTypeSymbol">A type that is, or derives from, <c>System.Threading.Tasks.Task&lt;T&gt;</c>.</param>
    /// <param name="semanticModel">Semantic model for the syntax tree of <paramref name="document"/>.</param>
    /// <param name="cancellationToken">Token observed while inspecting symbols and syntax.</param>
    /// <returns>
    /// The refactoring delegate, or <see langword="null"/> when the refactoring is not applicable. That is the case when:
    /// the type is not a <c>Task&lt;T&gt;</c>;
    /// the declaration is not inside an ordinary method or a local function;
    /// the containing declaration has no method/local-function body;
    /// or a <c>return</c> of the containing method returns a non-<c>await</c> expression.
    /// </returns>
    public static Func<CancellationToken, Task<Document>> ChangeTypeAndAddAwait(
        Document document,
        VariableDeclarationSyntax variableDeclaration,
        VariableDeclaratorSyntax variableDeclarator,
        ITypeSymbol newTypeSymbol,
        SemanticModel semanticModel,
        CancellationToken cancellationToken = default)
    {
        if (!newTypeSymbol.OriginalDefinition.EqualsOrInheritsFromTaskOfT())
            return default;

        // The type may merely derive from Task<T> (e.g. 'class MyTask : Task<int>'), in which case its own
        // type arguments are not the awaited result type (and may not exist at all). Resolve the actual
        // Task<T> in the base-type chain instead of indexing the type's own TypeArguments.
        INamedTypeSymbol taskOfT = FindTaskOfT(newTypeSymbol);

        if (taskOfT is null)
            return default;

        ITypeSymbol typeArgument = taskOfT.TypeArguments[0];

        if (semanticModel.GetEnclosingSymbol(variableDeclaration.SpanStart, cancellationToken) is not IMethodSymbol methodSymbol)
            return default;

        if (!methodSymbol.MethodKind.Is(MethodKind.Ordinary, MethodKind.LocalFunction))
            return default;

        SyntaxNode containingMethod = GetContainingMethod();

        if (containingMethod is null)
            return default;

        SyntaxNode bodyOrExpressionBody = GetBodyOrExpressionBody();

        if (bodyOrExpressionBody is null)
            return default;

        // Only the containing method's own return statements matter. Nested lambdas, anonymous methods and
        // local functions are still visited themselves, but their bodies are skipped because their return
        // statements return from those nested functions.
        foreach (SyntaxNode descendant in bodyOrExpressionBody.DescendantNodes(
            node => node is not AnonymousFunctionExpressionSyntax and not LocalFunctionStatementSyntax))
        {
            // A bare 'return;' is accepted; a returned value must be an (optionally parenthesized) await.
            if (descendant is ReturnStatementSyntax { Expression: ExpressionSyntax expression }
                && !expression.WalkDownParentheses().IsKind(SyntaxKind.AwaitExpression))
            {
                return default;
            }
        }

        return ct => DocumentRefactorings.ChangeTypeAndAddAwaitAsync(document, variableDeclaration, variableDeclarator, containingMethod, typeArgument, semanticModel, ct);

        // Finds the declaring syntax of 'methodSymbol' that contains the variable declaration, if any.
        SyntaxNode GetContainingMethod()
        {
            foreach (SyntaxReference syntaxReference in methodSymbol.DeclaringSyntaxReferences)
            {
                SyntaxNode syntax = syntaxReference.GetSyntax(cancellationToken);

                if (syntax.Contains(variableDeclaration))
                    return syntax;
            }

            return null;
        }

        // Returns the block body or expression body of the containing method or local function.
        // Any other declaring syntax yields null so the refactoring is simply not offered instead of
        // throwing. NOTE(review): presumably this includes the compilation unit that declares the
        // synthesized entry point of top-level statements.
        SyntaxNode GetBodyOrExpressionBody()
        {
            return containingMethod switch
            {
                MethodDeclarationSyntax methodDeclaration => methodDeclaration.BodyOrExpressionBody(),
                LocalFunctionStatementSyntax localFunction => localFunction.BodyOrExpressionBody(),
                _ => null,
            };
        }

        // Walks 'typeSymbol' and its base types and returns the first constructed
        // System.Threading.Tasks.Task`1, or null when none is found.
        static INamedTypeSymbol FindTaskOfT(ITypeSymbol typeSymbol)
        {
            for (ITypeSymbol current = typeSymbol; current is not null; current = current.BaseType)
            {
                if (current is INamedTypeSymbol { Arity: 1, MetadataName: "Task`1" } namedType
                    && namedType.ContainingNamespace?.ToDisplayString() == "System.Threading.Tasks")
                {
                    return namedType;
                }
            }

            return null;
        }
    }
}
